{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import gc\n",
    "import requests\n",
    "from safetensors.torch import load_file\n",
    "from diffusers import FluxPipeline\n",
    "from pipeline_flux_img2img import FluxImg2ImgPipeline\n",
    "from transformer_flux import FluxTransformer2DModel\n",
    "from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
    "device = torch.device('cuda')\n",
    "dtype = torch.bfloat16\n",
    "pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"enchanted forest, glowing plants, towering ancient trees, a mystical girl, magical aura, \" \\\n",
    "         \"fantasy style, vibrant colors, ethereal lighting, bokeh effect, ultra-detailed, painterly, ultra HD, 8K, \" \\\n",
    "         \"soft glowing lights, mist and fog, otherworldly ambiance, glowing mushrooms, sparkling particles\"\n",
    "height = 512\n",
    "width = 1024\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    height=height,\n",
    "    width=width,\n",
    "    guidance_scale=3.5,\n",
    "    num_inference_steps=20,\n",
    "    max_sequence_length=512,\n",
    "    generator=torch.Generator(\"cpu\").manual_seed(0)\n",
    ").images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del pipe\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()\n",
    "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
    "pipe = FluxImg2ImgPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=torch.bfloat16)\n",
    "pipe.transformer = transformer\n",
    "pipe.scheduler.config.use_dynamic_shifting = False\n",
    "pipe.scheduler.config.time_shift = 10\n",
    "pipe.vae.enable_tiling()\n",
    "pipe = pipe.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "height = 2048\n",
    "width = 4096\n",
    "down_factor, window_size = 4, 8\n",
    "# Supported Configurations:\n",
    "# down_factor, window_size = 1, 8\n",
    "# down_factor, window_size = 1, 16\n",
    "# down_factor, window_size = 1, 32\n",
    "# down_factor, window_size = 4, 16\n",
    "# down_factor, window_size = 4, 8\n",
    "if down_factor == 1:\n",
    "    init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\n",
    "    attn_processors = {}\n",
    "    for k in pipe.transformer.attn_processors.keys():\n",
    "        attn_processors[k] = LocalFlexAttnProcessor()\n",
    "else:\n",
    "    init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\n",
    "    attn_processors = {}\n",
    "    for k in pipe.transformer.attn_processors.keys():\n",
    "        attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\n",
    "pipe.transformer.set_attn_processor(attn_processors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('ckpt'):\n",
    "    os.mkdir('ckpt')\n",
    "if down_factor == 1:\n",
    "    if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\n",
    "        print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\n",
    "        response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\")\n",
    "        response.raise_for_status()\n",
    "        with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\n",
    "            f.write(response.content)\n",
    "    state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\n",
    "else:\n",
    "    if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\n",
    "        print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
    "        response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\")\n",
    "        response.raise_for_status()\n",
    "        with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\n",
    "            f.write(response.content)\n",
    "    state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
    "\n",
    "missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\n",
    "\n",
    "missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \n",
    "                                      '.attn.to_k.' in p or \n",
    "                                      '.attn.to_v.' in p or \n",
    "                                      '.attn.to_out.' in p or \n",
    "                                      'spatial_weight' in p), missing_keys))\n",
    "\n",
    "if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n",
    "    print(\n",
    "        f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n",
    "        f\" and missing keys: {missing_keys}.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "strength = 0.7\n",
    "image_hr = pipe(prompt=prompt,\n",
    "                image=image.resize((width, height)),\n",
    "                strength=strength,\n",
    "                num_inference_steps=20, \n",
    "                guidance_scale=7.5, \n",
    "                height=height,\n",
    "                width=width,\n",
    "                ntk_factor=10,\n",
    "                proportional_attention=True,\n",
    "                generator=torch.Generator(\"cpu\").manual_seed(0)\n",
    "                ).images[0]\n",
    "image_hr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt25",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
